import torch

# Scratch check: round the scores in `x` to the nearest integer and compare
# them element-wise with the 0/1 values in `y`, printing the boolean result.
# NOTE(review): x looks like predicted probabilities and y like binary labels;
# confirm with the author before reusing this logic elsewhere.
x = torch.tensor([0.51, 0.7, 0.3])  # default float dtype (float32 unless changed)
y = torch.tensor([1, 1, 0])  # int64, built from Python ints

# torch.round rounds half to even, so an exact 0.5 rounds to 0 (not 1).
# Cast y to x's dtype so the comparison is float vs. float explicitly rather
# than relying on torch.eq's implicit type promotion; the result is identical.
p = torch.eq(torch.round(x), y.to(x.dtype))  # bool tensor, same shape as x
print(p)